{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 8.3 自动并行计算"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-05-10T16:16:41.669018Z",
     "start_time": "2019-05-10T16:16:36.457355Z"
    }
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "import time\n",
    "\n",
    "# The cells below place tensors on two different GPUs, so fail early with a clear reason.\n",
    "assert torch.cuda.device_count() >= 2, 'this notebook needs at least two CUDA devices'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-05-10T16:17:29.013953Z",
     "start_time": "2019-05-10T16:16:41.673871Z"
    }
   },
   "outputs": [],
   "source": [
    "# CUDA device indices are zero-based: with two GPUs the valid devices are cuda:0 and cuda:1.\n",
    "x_gpu1 = torch.rand(size=(100, 100), device='cuda:0')\n",
    "x_gpu2 = torch.rand(size=(100, 100), device='cuda:1')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-05-10T16:17:29.021652Z",
     "start_time": "2019-05-10T16:17:29.017222Z"
    }
   },
   "outputs": [],
   "source": [
    "class Benchmark:  # A copy of this class is kept in the d2lzh_pytorch package for later reuse\n",
    "    \"\"\"Context manager that prints the elapsed time of its `with` block.\n",
    "\n",
    "    Args:\n",
    "        prefix: Optional label printed before the timing. A single space is\n",
    "            appended when it is non-empty; otherwise nothing is printed before\n",
    "            'time:'.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, prefix=None):\n",
    "        self.prefix = prefix + ' ' if prefix else ''\n",
    "\n",
    "    def __enter__(self):\n",
    "        # perf_counter is monotonic and high-resolution; time.time() can jump\n",
    "        # when the system clock is adjusted, which corrupts short measurements.\n",
    "        self.start = time.perf_counter()\n",
    "        return self  # allows `with Benchmark(...) as b:`\n",
    "\n",
    "    def __exit__(self, *args):\n",
    "        # Returns None, so an exception raised inside the block still propagates.\n",
    "        print('%stime: %.4f sec' % (self.prefix, time.perf_counter() - self.start))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-05-10T16:17:29.069210Z",
     "start_time": "2019-05-10T16:17:29.023602Z"
    }
   },
   "outputs": [],
   "source": [
    "def run(x, num_iters=20000):\n",
    "    \"\"\"Compute `torch.mm(x, x)` repeatedly to generate work on x's device.\n",
    "\n",
    "    Args:\n",
    "        x: 2-D square tensor (torch.mm needs matching inner dimensions).\n",
    "        num_iters: Number of multiplications to issue; defaults to 20000.\n",
    "\n",
    "    Returns:\n",
    "        None. Each product is discarded; only the issued work matters.\n",
    "    \"\"\"\n",
    "    for _ in range(num_iters):\n",
    "        torch.mm(x, x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-05-10T16:17:29.767144Z",
     "start_time": "2019-05-10T16:17:29.071262Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Run on GPU1. time: 0.2989 sec\n",
      "Then run on GPU2. time: 0.3518 sec\n"
     ]
    }
   ],
   "source": [
    "with Benchmark('Run on GPU1.'):\n",
    "    run(x_gpu1)\n",
    "    # synchronize() with no argument waits only on the current device, so name\n",
    "    # the device that did the work before the timer stops.\n",
    "    torch.cuda.synchronize(x_gpu1.device)\n",
    "\n",
    "with Benchmark('Then run on GPU2.'):\n",
    "    run(x_gpu2)\n",
    "    torch.cuda.synchronize(x_gpu2.device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-05-10T16:17:30.282318Z",
     "start_time": "2019-05-10T16:17:29.770313Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Run on both GPU1 and GPU2 in parallel. time: 0.5076 sec\n"
     ]
    }
   ],
   "source": [
    "with Benchmark('Run on both GPU1 and GPU2 in parallel.'):\n",
    "    run(x_gpu1)\n",
    "    run(x_gpu2)\n",
    "    # Kernel launches are asynchronous; wait on both devices so the timing\n",
    "    # covers all queued work, not only that of the current device.\n",
    "    torch.cuda.synchronize(x_gpu1.device)\n",
    "    torch.cuda.synchronize(x_gpu2.device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:py36]",
   "language": "python",
   "name": "conda-env-py36-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  },
  "varInspector": {
   "cols": {
    "lenName": 16,
    "lenType": 16,
    "lenVar": 40
   },
   "kernels_config": {
    "python": {
     "delete_cmd_postfix": "",
     "delete_cmd_prefix": "del ",
     "library": "var_list.py",
     "varRefreshCmd": "print(var_dic_list())"
    },
    "r": {
     "delete_cmd_postfix": ") ",
     "delete_cmd_prefix": "rm(",
     "library": "var_list.r",
     "varRefreshCmd": "cat(var_dic_list()) "
    }
   },
   "types_to_exclude": [
    "module",
    "function",
    "builtin_function_or_method",
    "instance",
    "_Feature"
   ],
   "window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
